import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss computed from raw logits, mean-reduced.

    Args:
        logits: Unnormalized scores, any shape accepted by
            ``F.binary_cross_entropy_with_logits``.
        targets: Float targets with the same shape as ``logits``, in [0, 1].
        alpha: Class-balancing weight: positives are weighted by ``alpha`` and
            negatives by ``1 - alpha``. Pass ``None`` or a negative value to
            disable the weighting.
        gamma: Focusing exponent; larger values down-weight easy examples more.

    Returns:
        Scalar tensor: mean of the per-element focal loss.
    """
    bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    # For hard 0/1 targets exp(-BCE) equals p_t, the probability of the true class.
    pt = torch.exp(-bce_loss)
    loss = (1 - pt) ** gamma * bce_loss
    if alpha is not None and alpha >= 0:
        # Per-element alpha_t; a single constant alpha on every element would
        # merely rescale the loss instead of balancing positives vs negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()

def circle_loss(logits, labels, margin=0.25, gamma=80):
    """Circle-loss style objective for binary labels, reduced over the whole batch.

    Computes ``log(1 + sum(labels * a_p * exp(-gamma * (logits - 1 + margin))
    + (1 - labels) * a_n * exp(gamma * (logits + margin))))`` with
    ``a_p = relu(1 + margin - logits)`` and ``a_n = relu(logits + margin)``
    treated as constants (detached from the graph). The sum is evaluated in
    log space so large ``gamma * logits`` products cannot overflow to ``inf``.
    Such an overflow would make the loss ``inf`` and its gradients NaN.

    Args:
        logits: Score tensor. The margin constants suggest similarity-like
            scores around [-1, 1] -- NOTE(review): confirm with callers.
        labels: Targets in [0, 1] (1 = positive), broadcastable with
            ``logits``; converted to float.
        margin: Relaxation margin ``m``.
        gamma: Scale factor applied inside the exponentials.

    Returns:
        Scalar tensor: one value for the whole batch (a log-sum, not a mean).
    """
    labels = labels.float()
    alpha_p = torch.clamp_min(1 + margin - logits.detach(), min=0.)
    alpha_n = torch.clamp_min(logits.detach() + margin, min=0.)
    # NOTE(review): Sun et al. (2020) place alpha inside the exponent and use
    # (logits - margin) for negatives; the existing formula is kept unchanged.
    # Each term is log(weight) + exponent. A zero weight yields -inf, which
    # logsumexp treats as an absent term (with an exactly-zero gradient).
    pos_terms = torch.log(labels * alpha_p) - gamma * (logits - 1 + margin)
    neg_terms = torch.log((1 - labels) * alpha_n) + gamma * (logits + margin)
    terms = torch.cat([pos_terms.reshape(-1), neg_terms.reshape(-1)])
    # Prepending log(1) = 0 makes this equal log(1 + sum(exp(terms))) and keeps
    # the result (and its gradient) finite even when every weight is zero.
    return torch.logsumexp(torch.cat([terms.new_zeros(1), terms]), dim=0)

def contrastive_loss(emb_a, emb_b, labels, margin=1.0, eps=1e-12):
    """Pairwise contrastive loss, mean-reduced over the batch.

    Similar pairs (label 1) are penalized by their squared Euclidean distance.
    Dissimilar pairs (label 0) are penalized by ``relu(margin - distance) ** 2``.

    Args:
        emb_a, emb_b: Embedding batches. Distances are reduced over dim 1, so
            presumably shaped (batch, dim) -- confirm with callers.
        labels: Per-pair labels, 1 = similar, 0 = dissimilar. Integer or bool
            labels are converted to the embeddings' float dtype.
        margin: Distance beyond which dissimilar pairs contribute nothing.
        eps: Lower bound on the squared distance before ``sqrt``. The
            derivative of ``sqrt`` is infinite at 0, so identical embeddings
            would otherwise produce NaN gradients. This affects similar pairs
            too: their dissimilar term is multiplied by 0, and 0 * inf is NaN.

    Returns:
        Scalar tensor: mean loss over the batch.
    """
    sq_dist = (emb_a - emb_b).pow(2).sum(dim=1)
    if not labels.is_floating_point():
        # `1 - labels` is not defined for bool tensors.
        labels = labels.to(sq_dist.dtype)
    dist = sq_dist.clamp_min(eps).sqrt()
    loss = labels * sq_dist + (1 - labels) * F.relu(margin - dist).pow(2)
    return loss.mean()

def weighted_bce_loss(logits, targets, pos_weight=2.0):
    """Binary cross-entropy on logits with extra weight on positive targets.

    Args:
        logits: Unnormalized scores.
        targets: Float targets with the same shape as ``logits``.
        pos_weight: Multiplier for the positive-target term. Either a Python
            number or a tensor (e.g. one weight per class) that
            ``F.binary_cross_entropy_with_logits`` broadcasts with ``targets``.

    Returns:
        Scalar tensor: mean-reduced weighted BCE.
    """
    # torch.as_tensor accepts an existing tensor without torch.tensor's
    # copy-construct warning. Matching the logits' dtype/device avoids mixed
    # precision inside the loss.
    pos_weight = torch.as_tensor(pos_weight, dtype=logits.dtype, device=logits.device)
    return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)

def auc_margin_loss(logits, labels, margin=1.0):
    """Pairwise hinge surrogate for AUC: every positive should outscore every negative.

    For each (positive, negative) pair the penalty is
    ``relu(margin - (pos_score - neg_score))``; the result is the mean over all
    pairs. An intermediate matrix of size (#positives, #negatives) is built.

    Args:
        logits: Scores; selected with the boolean masks ``labels == 1`` and
            ``labels == 0``, so ``labels`` must be mask-compatible with it.
        labels: Binary labels (1 = positive, 0 = negative).
        margin: Required score gap between a positive and a negative.

    Returns:
        Scalar tensor. Zero when either class is absent from the batch.
    """
    pos_scores = logits[labels == 1]
    neg_scores = logits[labels == 0]
    if pos_scores.numel() == 0 or neg_scores.numel() == 0:
        # No (pos, neg) pair exists. Return a zero that stays attached to the
        # graph so `loss.backward()` yields zero gradients instead of raising.
        return logits.sum() * 0.0
    gaps = pos_scores.unsqueeze(1) - neg_scores.unsqueeze(0)
    return F.relu(margin - gaps).mean()

def compute_loss(loss_type, logits, labels, emb_a=None, emb_b=None):
    """Compute the loss named by *loss_type*.

    Args:
        loss_type: One of "bce", "focal", "circle", "contrastive",
            "weighted_bce", "auc_margin".
        logits: Model scores (unused by "contrastive").
        labels: Targets passed to the selected loss.
        emb_a, emb_b: Embedding pair; required only for "contrastive".

    Returns:
        Whatever the selected loss function returns.

    Raises:
        ValueError: if *loss_type* is unknown, or if "contrastive" is requested
            without both embeddings.
    """
    if loss_type == "bce":
        return F.binary_cross_entropy_with_logits(logits, labels)
    elif loss_type == "focal":
        return focal_loss(logits, labels)
    elif loss_type == "circle":
        return circle_loss(logits, labels)
    elif loss_type == "contrastive":
        # Explicit check: an `assert` would be stripped under `python -O`.
        if emb_a is None or emb_b is None:
            raise ValueError("contrastive loss requires both emb_a and emb_b")
        return contrastive_loss(emb_a, emb_b, labels)
    elif loss_type == "weighted_bce":
        return weighted_bce_loss(logits, labels)
    elif loss_type == "auc_margin":
        return auc_margin_loss(logits, labels)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")


def get_loss_func(loss_type):
    """Return a loss callable for *loss_type* with a keyword-tolerant signature.

    Every returned callable is invoked as ``fn(logits, labels, ...)`` and
    silently ignores extra keyword arguments. The "contrastive" callable
    additionally requires ``emb_a`` and ``emb_b``. Names are matched with
    ``==`` in a fixed order.

    Raises:
        ValueError: if *loss_type* matches none of the known names.
    """
    registry = (
        ("bce", lambda logits, labels, **_: F.binary_cross_entropy_with_logits(logits, labels)),
        ("focal", lambda logits, labels, **_: focal_loss(logits, labels)),
        ("circle", lambda logits, labels, **_: circle_loss(logits, labels)),
        ("contrastive", lambda logits, labels, emb_a, emb_b, **_: contrastive_loss(emb_a, emb_b, labels)),
        ("weighted_bce", lambda logits, labels, **_: weighted_bce_loss(logits, labels)),
        ("auc_margin", lambda logits, labels, **_: auc_margin_loss(logits, labels)),
    )
    for name, loss_fn in registry:
        if loss_type == name:
            return loss_fn
    raise ValueError(f"Unknown loss type: {loss_type}")
